{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Bert实现命名实体识别任务"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 1.加载数据\n",
    "加载数据以及数据的展示，这里使用最常见的conll2003数据集进行实验"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from datasets import Dataset, load_dataset, load_metric\n",
    "\n",
    "# Experiment configuration shared by the cells below.\n",
    "task = \"ner\"  # one of \"ner\", \"pos\" or \"chunk\"\n",
    "model_checkpoint = \"distilbert-base-uncased\"  # checkpoint used for both the tokenizer and the model\n",
    "batch_size = 16  # per-device batch size for training and evaluation\n",
    "\n",
    "# CoNLL-2003 dataset with train / validation / test splits.\n",
    "datasets = load_dataset(\"conll2003\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "展示数据集的第一条数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'id': '0',\n",
       " 'tokens': ['EU',\n",
       "  'rejects',\n",
       "  'German',\n",
       "  'call',\n",
       "  'to',\n",
       "  'boycott',\n",
       "  'British',\n",
       "  'lamb',\n",
       "  '.'],\n",
       " 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7],\n",
       " 'chunk_tags': [11, 21, 11, 12, 21, 22, 11, 12, 0],\n",
       " 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0]}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "datasets[\"train\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "datasets[\"train\"].features[\"ner_tags\"].feature.names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['O',\n",
       " 'B-ADJP',\n",
       " 'I-ADJP',\n",
       " 'B-ADVP',\n",
       " 'I-ADVP',\n",
       " 'B-CONJP',\n",
       " 'I-CONJP',\n",
       " 'B-INTJ',\n",
       " 'I-INTJ',\n",
       " 'B-LST',\n",
       " 'I-LST',\n",
       " 'B-NP',\n",
       " 'I-NP',\n",
       " 'B-PP',\n",
       " 'I-PP',\n",
       " 'B-PRT',\n",
       " 'I-PRT',\n",
       " 'B-SBAR',\n",
       " 'I-SBAR',\n",
       " 'B-UCP',\n",
       " 'I-UCP',\n",
       " 'B-VP',\n",
       " 'I-VP']"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "datasets[\"train\"].features[\"chunk_tags\"].feature.names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['\"',\n",
       " \"''\",\n",
       " '#',\n",
       " '$',\n",
       " '(',\n",
       " ')',\n",
       " ',',\n",
       " '.',\n",
       " ':',\n",
       " '``',\n",
       " 'CC',\n",
       " 'CD',\n",
       " 'DT',\n",
       " 'EX',\n",
       " 'FW',\n",
       " 'IN',\n",
       " 'JJ',\n",
       " 'JJR',\n",
       " 'JJS',\n",
       " 'LS',\n",
       " 'MD',\n",
       " 'NN',\n",
       " 'NNP',\n",
       " 'NNPS',\n",
       " 'NNS',\n",
       " 'NN|SYM',\n",
       " 'PDT',\n",
       " 'POS',\n",
       " 'PRP',\n",
       " 'PRP$',\n",
       " 'RB',\n",
       " 'RBR',\n",
       " 'RBS',\n",
       " 'RP',\n",
       " 'SYM',\n",
       " 'TO',\n",
       " 'UH',\n",
       " 'VB',\n",
       " 'VBD',\n",
       " 'VBG',\n",
       " 'VBN',\n",
       " 'VBP',\n",
       " 'VBZ',\n",
       " 'WDT',\n",
       " 'WP',\n",
       " 'WP$',\n",
       " 'WRB']"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "datasets[\"train\"].features[\"pos_tags\"].feature.names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "**tokenizer 不仅可以对字符串进行序列化，还可以对分词后的 token 列表进行序列化，此时需要设置 is_split_into_words=True。**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': [101, 7592, 1010, 2023, 2003, 2028, 6251, 3975, 2046, 2616, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer([\"Hello\", \",\", \"this\", \"is\", \"one\", \"sentence\", \"split\", \"into\", \"words\", \".\"], is_split_into_words=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[CLS] hello, this is one sentence split into words. [SEP]'"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode([101, 7592, 1010, 2023, 2003, 2028, 6251, 3975, 2046, 2616, 1012, 102])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "**值得注意的是 tokenizer可能将 单词 分割成  单词的 词根或词缀 即  经过 tokenizer后 序列的长度可能发生改变<br>\n",
    "Transformers通常使用子词标记器进行预训练，这意味着即使您的输入已经被分割成单词，这些单词中的每一个都可以被标记器再次分割。让我们看一个例子:**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "原始token: ['Germany', \"'s\", 'representative', 'to', 'the', 'European', 'Union', \"'s\", 'veterinary', 'committee', 'Werner', 'Zwingmann', 'said', 'on', 'Wednesday', 'consumers', 'should', 'buy', 'sheepmeat', 'from', 'countries', 'other', 'than', 'Britain', 'until', 'the', 'scientific', 'advice', 'was', 'clearer', '.']\n",
      "----------------------------------------------------------------------------------------------------\n",
      "转换后的 token: ['[CLS]', 'germany', \"'\", 's', 'representative', 'to', 'the', 'european', 'union', \"'\", 's', 'veterinary', 'committee', 'werner', 'z', '##wing', '##mann', 'said', 'on', 'wednesday', 'consumers', 'should', 'buy', 'sheep', '##me', '##at', 'from', 'countries', 'other', 'than', 'britain', 'until', 'the', 'scientific', 'advice', 'was', 'clearer', '.', '[SEP]']\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"[CLS] germany's representative to the european union's veterinary committee werner zwingmann said on wednesday consumers should buy sheepmeat from countries other than britain until the scientific advice was clearer. [SEP]\""
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pick one training sentence and show its original word-level tokens.\n",
    "example = datasets[\"train\"][4]\n",
    "print(\"原始token:\", example[\"tokens\"])\n",
    "print(\"-\" * 100)\n",
    "\n",
    "# The sentence is already split into words, hence is_split_into_words=True.\n",
    "tokenized_input = tokenizer(example[\"tokens\"], is_split_into_words=True)\n",
    "tokens = tokenizer.convert_ids_to_tokens(tokenized_input[\"input_ids\"])\n",
    "print(\"转换后的 token:\", tokens)\n",
    "print(\"-\" * 100)\n",
    "\n",
    "# Decode the ids back to text; this is the value displayed by the cell.\n",
    "tokenizer.decode(tokenized_input[\"input_ids\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "**这意味着我们需要对标签做一些处理。因为 tokenizer 返回的 id 比我们的数据集所包含的标签列表要长，其原因是单词被再次拆分，<br>或者添加了一些特殊的标记，例如 CLS 和 SEP。**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(31, 39)"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(example[f\"{task}_tags\"]), len(tokenized_input[\"input_ids\"])  # word-level tag count vs. sub-token id count: the two lengths differ"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "**为此，我们可以使用 tokenized_input.word_ids()方法 来进行操作**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[None, 0, 1, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9, 10, 11, 11, 11, 12, 13, 14, 15, 16, 17, 18, 18, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, None]\n",
      "39\n"
     ]
    }
   ],
   "source": [
    "# word_ids() gives, for every sub-token, the index of the word it came from (None for special tokens).\n",
    "print(tokenized_input.word_ids(), len(tokenized_input.word_ids()), sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "**正如我们所看到的，它返回一个列表，其中的元素数量与我们处理过的输入id相同，将特殊标记映射为None，将所有其他标记映射为各自的词。<br>这样，我们就可以将标签与处理后的输入id对齐。(其中 相同的数字表示 由同一个词 拆分而成的 子token )**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "39 39\n"
     ]
    }
   ],
   "source": [
    "word_ids = tokenized_input.word_ids()\n",
    "# Special tokens have no source word (word id None) and get -100;\n",
    "# every other sub-token takes the tag of the word it came from.\n",
    "aligned_labels = [\n",
    "    example[f\"{task}_tags\"][word_id] if word_id is not None else -100\n",
    "    for word_id in word_ids\n",
    "]\n",
    "print(len(aligned_labels), len(tokenized_input[\"input_ids\"]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "**在这里，我们将所有特殊标记的标签设置为-100（PyTorch所忽略的索引），将所有其他标记的标签设置为它们所来自的单词的标签。另一种策略是只对从一个给定的单词中获得的第一个标记设置标签，而对来自同一单词的其他子标记给予-100的标签。我们在此提出这两种策略，只需改变以下标志的值。**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
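    "label_all_tokens = True  # True: every sub-token gets its word's label (strategy 1); False: only the first sub-token of each word is labelled (strategy 2)"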
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def tokenize_and_align_labels(examples):\n",
    "    \"\"\"Tokenize a batch of pre-split sentences and align the word-level tags to sub-tokens.\n",
    "\n",
    "    Reads the notebook globals ``tokenizer``, ``task`` and ``label_all_tokens``.\n",
    "\n",
    "    Args:\n",
    "        examples: batch holding a 'tokens' column (one word list per sentence) and a\n",
    "            '<task>_tags' column (one tag id per word).\n",
    "\n",
    "    Returns:\n",
    "        The tokenizer output with an added 'labels' entry, one label per sub-token:\n",
    "        -100 for special tokens, the word's tag for the first sub-token of each word, and\n",
    "        the word's tag (label_all_tokens truthy) or -100 (otherwise) for its other sub-tokens.\n",
    "    \"\"\"\n",
    "    encoded = tokenizer(examples[\"tokens\"], truncation=True, is_split_into_words=True)\n",
    "\n",
    "    aligned = []\n",
    "    for row, word_tags in enumerate(examples[f\"{task}_tags\"]):\n",
    "        row_labels = []\n",
    "        last_word = None\n",
    "        for current_word in encoded.word_ids(batch_index=row):\n",
    "            if current_word is None:\n",
    "                # Special tokens are not tied to any word.\n",
    "                tag = -100\n",
    "            elif current_word != last_word or label_all_tokens:\n",
    "                # First sub-token of a word, or a continuation sub-token when every\n",
    "                # sub-token is to be labelled.\n",
    "                tag = word_tags[current_word]\n",
    "            else:\n",
    "                # Continuation sub-token that should not carry a label.\n",
    "                tag = -100\n",
    "            row_labels.append(tag)\n",
    "            last_word = current_word\n",
    "        aligned.append(row_labels)\n",
    "\n",
    "    encoded[\"labels\"] = aligned\n",
    "    return encoded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': [[101, 7327, 19164, 2446, 2655, 2000, 17757, 2329, 12559, 1012, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'labels': [[-100, 3, 0, 7, 0, 0, 0, 7, 0, 0, -100]]}"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenize_and_align_labels(datasets['train'][:1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['[CLS] eu rejects german call to boycott british lamb. [SEP]']"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.batch_decode(tokenize_and_align_labels(datasets['train'][:1])[\"input_ids\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "**将数据集整体 进行token对齐操作    调用map（）**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n",
       "        num_rows: 14042\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n",
       "        num_rows: 3251\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n",
       "        num_rows: 3454\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e3e4e41d2d5345cb8cb80b94377013c0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/15 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bdc9b362a50d47728dc0f20ce76307cb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bfe716db8181499cb66cc7fd6e56eb4a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tokenized_datasets = datasets.map(tokenize_and_align_labels, batched=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags', 'input_ids', 'attention_mask', 'labels'],\n",
       "        num_rows: 14042\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags', 'input_ids', 'attention_mask', 'labels'],\n",
       "        num_rows: 3251\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags', 'input_ids', 'attention_mask', 'labels'],\n",
       "        num_rows: 3454\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenized_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[-100, 3, 0, 7, 0, 0, 0, 7, 0, 0, -100]"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Index the row first: column-first access builds the whole 'labels' column just to read one item.\n",
    "tokenized_datasets[\"train\"][0][\"labels\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[3, 0, 7, 0, 0, 0, 7, 0, 0]"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Index the row first: column-first access builds the whole 'ner_tags' column just to read one item.\n",
    "tokenized_datasets[\"train\"][0][\"ner_tags\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "**进行模型Fine-tuning**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a0fd24a466244ab592bce45fa025e9c7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/256M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForTokenClassification: ['vocab_projector.bias', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight']\n",
      "- This IS expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer\n",
    "label_list = datasets[\"train\"].features[f\"{task}_tags\"].feature.names\n",
    "# Store the id <-> tag-name mapping in the model config so that saved checkpoints\n",
    "# report readable tag names (e.g. B-PER) instead of generic LABEL_<n> names.\n",
    "id2label = dict(enumerate(label_list))\n",
    "label2id = {name: idx for idx, name in id2label.items()}\n",
    "model = AutoModelForTokenClassification.from_pretrained(\n",
    "    model_checkpoint,\n",
    "    num_labels=len(label_list),\n",
    "    id2label=id2label,\n",
    "    label2id=label2id,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Checkpoints are written to \"<checkpoint name>-finetuned-<task>\".\n",
    "model_name = model_checkpoint.split(\"/\")[-1]\n",
    "args = TrainingArguments(\n",
    "    output_dir=f\"{model_name}-finetuned-{task}\",\n",
    "    # Evaluate on the validation split once per epoch.\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    num_train_epochs=3,\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    learning_rate=2e-5,\n",
    "    weight_decay=0.01,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import DataCollatorForTokenClassification\n",
    "\n",
    "data_collator = DataCollatorForTokenClassification(tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "Looking in indexes: https://mirrors.ustc.edu.cn/pypi/web/simple\n",
      "\u001b[33mWARNING: Retrying (Retry(total=4, connect=None, read=None, redirect=None, status=None)) after connection broken by 'NewConnectionError('<pip._vendor.urllib3.connection.HTTPSConnection object at 0x7fc0e6d54cd0>: Failed to establish a new connection: [Errno -2] Name or service not known')': /pypi/web/simple/seqeval/\u001b[0m\u001b[33m\n",
      "\u001b[0m\u001b[33mWARNING: Retrying (Retry(total=3, connect=None, read=None, redirect=None, status=None)) after connection broken by 'NewConnectionError('<pip._vendor.urllib3.connection.HTTPSConnection object at 0x7fc0e6d6e190>: Failed to establish a new connection: [Errno 101] Network is unreachable')': /pypi/web/simple/seqeval/\u001b[0m\u001b[33m\n",
      "\u001b[0mCollecting seqeval\n",
      "  Downloading https://mirrors.bfsu.edu.cn/pypi/web/packages/9d/2d/233c79d5b4e5ab1dbf111242299153f3caddddbb691219f363ad55ce783d/seqeval-1.2.2.tar.gz (43 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.6/43.6 KB\u001b[0m \u001b[31m2.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25ldone\n",
      "\u001b[?25hRequirement already satisfied: numpy>=1.14.0 in /home/zutnlp/miniconda3/envs/liuyu/lib/python3.7/site-packages (from seqeval) (1.19.5)\n",
      "Requirement already satisfied: scikit-learn>=0.21.3 in /home/zutnlp/miniconda3/envs/liuyu/lib/python3.7/site-packages (from seqeval) (1.0.2)\n",
      "Requirement already satisfied: scipy>=1.1.0 in /home/zutnlp/miniconda3/envs/liuyu/lib/python3.7/site-packages (from scikit-learn>=0.21.3->seqeval) (1.7.3)\n",
      "Requirement already satisfied: threadpoolctl>=2.0.0 in /home/zutnlp/miniconda3/envs/liuyu/lib/python3.7/site-packages (from scikit-learn>=0.21.3->seqeval) (2.2.0)\n",
      "Requirement already satisfied: joblib>=0.11 in /home/zutnlp/miniconda3/envs/liuyu/lib/python3.7/site-packages (from scikit-learn>=0.21.3->seqeval) (1.1.0)\n",
      "Building wheels for collected packages: seqeval\n",
      "  Building wheel for seqeval (setup.py) ... \u001b[?25ldone\n",
      "\u001b[?25h  Created wheel for seqeval: filename=seqeval-1.2.2-py3-none-any.whl size=16180 sha256=c2ceaaad2863968f5428daced7949c6ef089729d9617dc5ed07bd6a4893a8ac6\n",
      "  Stored in directory: /home/zutnlp/.cache/pip/wheels/11/f9/5f/edc55bc2839444a3a60c455e3a9e75879a3e489c06fd92bdf2\n",
      "Successfully built seqeval\n",
      "Installing collected packages: seqeval\n",
      "Successfully installed seqeval-1.2.2\n",
      "\u001b[33mWARNING: You are using pip version 22.0.3; however, version 22.1 is available.\n",
      "You should consider upgrading via the '/home/zutnlp/miniconda3/envs/liuyu/bin/python -m pip install --upgrade pip' command.\u001b[0m\u001b[33m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "# %pip installs into the environment of the running kernel; the version is pinned so reruns are reproducible.\n",
    "%pip install -q seqeval==1.2.2\n",
    "metric = load_metric(\"seqeval\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "**评估算法**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def compute_metrics(p):\n",
    "    \"\"\"Score model predictions with the seqeval metric.\n",
    "\n",
    "    Reads the notebook globals ``label_list`` and ``metric``.\n",
    "\n",
    "    Args:\n",
    "        p: pair ``(predictions, labels)``. The predictions are arg-maxed over axis 2, so they\n",
    "            are presumably per-token class scores shaped (batch, seq_len, num_labels) - TODO confirm.\n",
    "\n",
    "    Returns:\n",
    "        Dict with the overall 'precision', 'recall', 'f1' and 'accuracy' reported by the metric.\n",
    "    \"\"\"\n",
    "    scores, gold_ids = p\n",
    "    predicted_ids = np.argmax(scores, axis=2)\n",
    "\n",
    "    true_predictions, true_labels = [], []\n",
    "    for predicted_row, gold_row in zip(predicted_ids, gold_ids):\n",
    "        # Drop the ignored positions (label -100) before scoring.\n",
    "        kept = [(pred, gold) for pred, gold in zip(predicted_row, gold_row) if gold != -100]\n",
    "        true_predictions.append([label_list[pred] for pred, _ in kept])\n",
    "        true_labels.append([label_list[gold] for _, gold in kept])\n",
    "\n",
    "    results = metric.compute(predictions=true_predictions, references=true_labels)\n",
    "    summary = {\n",
    "        \"precision\": results[\"overall_precision\"],\n",
    "        \"recall\": results[\"overall_recall\"],\n",
    "        \"f1\": results[\"overall_f1\"],\n",
    "        \"accuracy\": results[\"overall_accuracy\"],\n",
    "    }\n",
    "    return summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Wire the model, training arguments, tokenized splits, collator and metric function together.\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    "    tokenizer=tokenizer,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The following columns in the training set  don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: id, chunk_tags, pos_tags, ner_tags, tokens. If id, chunk_tags, pos_tags, ner_tags, tokens are not expected by `DistilBertForTokenClassification.forward`,  you can safely ignore this message.\n",
      "/home/zutnlp/miniconda3/envs/liuyu/lib/python3.7/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  FutureWarning,\n",
      "***** Running training *****\n",
      "  Num examples = 14042\n",
      "  Num Epochs = 3\n",
      "  Instantaneous batch size per device = 16\n",
      "  Total train batch size (w. parallel, distributed & accumulation) = 16\n",
      "  Gradient Accumulation steps = 1\n",
      "  Total optimization steps = 2634\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='2' max='2634' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [   2/2634 : < :, Epoch 0.00/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to distilbert-base-uncased-finetuned-ner/checkpoint-500\n",
      "Configuration saved in distilbert-base-uncased-finetuned-ner/checkpoint-500/config.json\n",
      "Model weights saved in distilbert-base-uncased-finetuned-ner/checkpoint-500/pytorch_model.bin\n",
      "tokenizer config file saved in distilbert-base-uncased-finetuned-ner/checkpoint-500/tokenizer_config.json\n",
      "Special tokens file saved in distilbert-base-uncased-finetuned-ner/checkpoint-500/special_tokens_map.json\n",
      "The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: id, chunk_tags, pos_tags, ner_tags, tokens. If id, chunk_tags, pos_tags, ner_tags, tokens are not expected by `DistilBertForTokenClassification.forward`,  you can safely ignore this message.\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 3251\n",
      "  Batch size = 16\n",
      "Saving model checkpoint to distilbert-base-uncased-finetuned-ner/checkpoint-1000\n",
      "Configuration saved in distilbert-base-uncased-finetuned-ner/checkpoint-1000/config.json\n",
      "Model weights saved in distilbert-base-uncased-finetuned-ner/checkpoint-1000/pytorch_model.bin\n",
      "tokenizer config file saved in distilbert-base-uncased-finetuned-ner/checkpoint-1000/tokenizer_config.json\n",
      "Special tokens file saved in distilbert-base-uncased-finetuned-ner/checkpoint-1000/special_tokens_map.json\n",
      "Saving model checkpoint to distilbert-base-uncased-finetuned-ner/checkpoint-1500\n",
      "Configuration saved in distilbert-base-uncased-finetuned-ner/checkpoint-1500/config.json\n",
      "Model weights saved in distilbert-base-uncased-finetuned-ner/checkpoint-1500/pytorch_model.bin\n",
      "tokenizer config file saved in distilbert-base-uncased-finetuned-ner/checkpoint-1500/tokenizer_config.json\n",
      "Special tokens file saved in distilbert-base-uncased-finetuned-ner/checkpoint-1500/special_tokens_map.json\n",
      "The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: id, chunk_tags, pos_tags, ner_tags, tokens. If id, chunk_tags, pos_tags, ner_tags, tokens are not expected by `DistilBertForTokenClassification.forward`,  you can safely ignore this message.\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 3251\n",
      "  Batch size = 16\n",
      "Saving model checkpoint to distilbert-base-uncased-finetuned-ner/checkpoint-2000\n",
      "Configuration saved in distilbert-base-uncased-finetuned-ner/checkpoint-2000/config.json\n",
      "Model weights saved in distilbert-base-uncased-finetuned-ner/checkpoint-2000/pytorch_model.bin\n",
      "tokenizer config file saved in distilbert-base-uncased-finetuned-ner/checkpoint-2000/tokenizer_config.json\n",
      "Special tokens file saved in distilbert-base-uncased-finetuned-ner/checkpoint-2000/special_tokens_map.json\n",
      "Saving model checkpoint to distilbert-base-uncased-finetuned-ner/checkpoint-2500\n",
      "Configuration saved in distilbert-base-uncased-finetuned-ner/checkpoint-2500/config.json\n",
      "Model weights saved in distilbert-base-uncased-finetuned-ner/checkpoint-2500/pytorch_model.bin\n",
      "tokenizer config file saved in distilbert-base-uncased-finetuned-ner/checkpoint-2500/tokenizer_config.json\n",
      "Special tokens file saved in distilbert-base-uncased-finetuned-ner/checkpoint-2500/special_tokens_map.json\n",
      "The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: id, chunk_tags, pos_tags, ner_tags, tokens. If id, chunk_tags, pos_tags, ner_tags, tokens are not expected by `DistilBertForTokenClassification.forward`,  you can safely ignore this message.\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 3251\n",
      "  Batch size = 16\n",
      "\n",
      "\n",
      "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=2634, training_loss=0.08670986667218857, metrics={'train_runtime': 96.0161, 'train_samples_per_second': 438.739, 'train_steps_per_second': 27.433, 'total_flos': 510309848641824.0, 'train_loss': 0.08670986667218857, 'epoch': 3.0})"
      ]
     },
     "execution_count": 85,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Launch fine-tuning with `trainer` (defined in an earlier cell). The call returns a\n",
    "# TrainOutput (global step, training loss, runtime metrics) displayed as the cell result.\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: id, chunk_tags, pos_tags, ner_tags, tokens. If id, chunk_tags, pos_tags, ner_tags, tokens are not expected by `DistilBertForTokenClassification.forward`,  you can safely ignore this message.\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 3251\n",
      "  Batch size = 16\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1' max='204' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [  1/204 : < :]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 0.061025530099868774,\n",
       " 'eval_precision': 0.9237063246351173,\n",
       " 'eval_recall': 0.9345564380803222,\n",
       " 'eval_f1': 0.9290997052772062,\n",
       " 'eval_accuracy': 0.9831127774159213,\n",
       " 'eval_runtime': 2.6055,\n",
       " 'eval_samples_per_second': 1247.758,\n",
       " 'eval_steps_per_second': 78.297,\n",
       " 'epoch': 3.0}"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run the Trainer's evaluation loop on its evaluation set. The returned dict holds the\n",
    "# eval_* metrics (loss, precision, recall, F1, accuracy) plus runtime statistics.\n",
    "trainer.evaluate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The following columns in the test set  don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: id, chunk_tags, pos_tags, ner_tags, tokens. If id, chunk_tags, pos_tags, ner_tags, tokens are not expected by `DistilBertForTokenClassification.forward`,  you can safely ignore this message.\n",
      "***** Running Prediction *****\n",
      "  Num examples = 3454\n",
      "  Batch size = 16\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'LOC': {'precision': 0.8881818181818182,\n",
       "  'recall': 0.9199623352165726,\n",
       "  'f1': 0.9037927844588344,\n",
       "  'number': 2124},\n",
       " 'MISC': {'precision': 0.7567567567567568,\n",
       "  'recall': 0.7309236947791165,\n",
       "  'f1': 0.7436159346271707,\n",
       "  'number': 996},\n",
       " 'ORG': {'precision': 0.8615443134271586,\n",
       "  'recall': 0.875193199381762,\n",
       "  'f1': 0.8683151236342725,\n",
       "  'number': 2588},\n",
       " 'PER': {'precision': 0.9602673598217601,\n",
       "  'recall': 0.9514348785871964,\n",
       "  'f1': 0.9558307152097579,\n",
       "  'number': 2718},\n",
       " 'overall_precision': 0.887906647807638,\n",
       " 'overall_recall': 0.8940185141229527,\n",
       " 'overall_f1': 0.8909520993494974,\n",
       " 'overall_accuracy': 0.974760030384642}"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run inference on the test split; `predictions` first holds the raw model outputs.\n",
    "predictions, labels, _ = trainer.predict(tokenized_datasets[\"test\"])\n",
    "# Keep the arg-max index along axis 2 (presumably the class axis -- confirm the shape).\n",
    "predictions = np.argmax(predictions, axis=2)\n",
    "\n",
    "# Positions whose reference label is -100 (special tokens) are ignored: drop them from\n",
    "# both sequences and map the remaining ids to their names via `label_list`.\n",
    "true_predictions = [\n",
    "    [label_list[pred_id] for pred_id, gold_id in zip(pred_row, gold_row) if gold_id != -100]\n",
    "    for pred_row, gold_row in zip(predictions, labels)\n",
    "]\n",
    "true_labels = [\n",
    "    [label_list[gold_id] for _, gold_id in zip(pred_row, gold_row) if gold_id != -100]\n",
    "    for pred_row, gold_row in zip(predictions, labels)\n",
    "]\n",
    "\n",
    "# Compute the metric on the filtered sequences; the bare name on the last line makes the\n",
    "# resulting dict the cell output.\n",
    "results = metric.compute(predictions=true_predictions, references=true_labels)\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
